/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sedona.core.spatialPartitioning.quadtree;

import java.io.Serializable;
import java.util.*;
import org.apache.sedona.common.utils.HalfOpenRectangle;
import org.apache.sedona.core.spatialPartitioning.PartitioningUtils;
import org.apache.sedona.core.spatialPartitioning.QuadTreeRTPartitioning;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.index.strtree.STRtree;
import scala.Tuple2;

/**
 * The ExtendedQuadTree class uses a modified quad-tree approach for partitioning spatial data, as
 * described in "Simba: Efficient In-Memory Spatial Analytics".
 *
 * <p>In this approach, a global R-tree index is constructed by taking a set of random samples from
 * the dataset and building the R-tree on the master node. For each partition, the distance from the
 * furthest point in the partition to its centroid is calculated. Using the R-tree, the k-nearest
 * neighbors of each centroid are found, and a distance bound is derived for each partition. This
 * bound ensures that the k-nearest neighbors of any point in a partition can be found within a
 * subset of the data identified by a circle range query centered at the centroid with the derived
 * radius. This method guarantees efficient and accurate k-nearest neighbor joins by leveraging both
 * local and global spatial indexing.
 */
public class ExtendedQuadTree<T> extends PartitioningUtils implements Serializable {
  // Holds the samples temporarily until build() consumes them. The field is transient, so the
  // samples are never serialized together with the tree.
  private final transient List<Envelope> samples = new ArrayList<>();

  private final Envelope boundary;
  private final int numPartitions;

  // original partitioning quad tree
  // for non-overlapped partitioning, the original quad tree is used
  private StandardQuadTree<Integer> partitionTree;

  // The expanded boundaries are generated by the quad tree + r-tree
  // for overlapped partitioning, the expanded boundaries are used
  private HashMap<Integer, List<Envelope>> expandedBoundaries;

  // The spatial index for partitioned MBRs
  // for overlapped partitioning, the spatial index is used
  private STRtree spatialExpandedBoundaryIndex;

  // Selects which structure placeObject() and getKeys() consult:
  // true -> partitionTree (non-overlapped), false -> spatialExpandedBoundaryIndex (overlapped)
  private boolean useNonOverlapped = false;

  /**
   * Returns the expanded boundaries used for overlapped partitioning.
   *
   * @return the internal map of expanded boundaries (not a copy); {@code null} until it has been
   *     populated by {@link #build(int)} or copied from another tree
   */
  public HashMap<Integer, List<Envelope>> getExpandedBoundaries() {
    return expandedBoundaries;
  }

  /**
   * Returns the spatial index over the expanded boundaries used for overlapped partitioning.
   *
   * @return the internal spatial index; {@code null} until it has been populated by {@link
   *     #build(int)} or copied from another tree
   */
  public STRtree getSpatialExpandedBoundaryIndex() {
    return spatialExpandedBoundaryIndex;
  }

  /**
   * Creates an empty tree in overlapped mode. Samples must be added with {@link
   * #insert(Envelope)} and {@link #build(int)} must be called before the tree can place objects.
   *
   * @param boundary the overall boundary handed to the partitioner in {@link #build(int)}
   * @param numPartitions the requested number of partitions
   */
  public ExtendedQuadTree(Envelope boundary, int numPartitions) {
    this.boundary = boundary;
    this.numPartitions = numPartitions;
  }

  /**
   * Creates a view of an existing tree that shares its boundary, quad-tree, expanded boundaries
   * and spatial index, but uses its own overlapped / non-overlapped mode. The structures are
   * shared by reference, not copied; the pending samples of the source tree are not carried over.
   *
   * @param extendedQuadTree the tree whose structures are shared
   * @param useNonOverlapped {@code true} to answer queries from the original quad-tree, {@code
   *     false} to answer them from the expanded boundaries
   */
  public ExtendedQuadTree(ExtendedQuadTree<?> extendedQuadTree, boolean useNonOverlapped) {
    this.boundary = extendedQuadTree.boundary;
    this.numPartitions = extendedQuadTree.numPartitions;
    this.expandedBoundaries = extendedQuadTree.expandedBoundaries;
    this.spatialExpandedBoundaryIndex = extendedQuadTree.spatialExpandedBoundaryIndex;
    this.partitionTree = extendedQuadTree.partitionTree;
    this.useNonOverlapped = useNonOverlapped;
  }

  /**
   * Returns the boundary of the partition zones.
   *
   * @return the boundary supplied at construction time
   */
  public Envelope getBoundary() {
    return boundary;
  }

  /**
   * Returns the number of partitions.
   *
   * @return the number of partitions requested at construction time
   */
  public int getPartitionNum() {
    return numPartitions;
  }

  /**
   * Insert a new sample into the sample list.
   *
   * @param sample Envelope object to be inserted.
   */
  public void insert(Envelope sample) {
    samples.add(sample);
  }

  /**
   * Check the geometry against the partition zones to find the IDs of the matching partitions.
   *
   * <p>In non-overlapped mode the centroid of the geometry is looked up in the original quad-tree
   * and only zones whose half-open rectangle contains the centroid are returned; a geometry with
   * an empty centroid yields no result. In overlapped mode the envelope of the geometry is
   * queried against the expanded-boundary index, so the geometry can be placed in multiple
   * partitions because ranges can overlap.
   *
   * @param geometry Geometry object to be placed.
   * @return Iterator of Tuple2 containing partition ID and the corresponding geometry.
   * @throws NullPointerException if {@code geometry} is {@code null}
   */
  @Override
  public Iterator<Tuple2<Integer, Geometry>> placeObject(Geometry geometry) {
    Objects.requireNonNull(geometry, "spatialObject");

    if (useNonOverlapped) {
      // KNN join uses geometry's centroid to calculate the distance
      final Point centroid = geometry.getCentroid();
      final Set<Tuple2<Integer, Geometry>> result = new HashSet<>();

      // Ignore null or empty point: nothing can contain it, so skip the tree lookup entirely
      if (centroid == null || centroid.isEmpty()) {
        return result.iterator();
      }

      final List<QuadRectangle> matchedPartitions =
          partitionTree.findZones(new QuadRectangle(centroid.getEnvelopeInternal()));

      for (QuadRectangle rectangle : matchedPartitions) {
        // For points, make sure to return only one partition
        if (new HalfOpenRectangle(rectangle.getEnvelope()).contains(centroid)) {
          result.add(new Tuple2<>(rectangle.partitionId, geometry));
        }
      }

      return result.iterator();
    }

    // use the expanded boundaries
    final List<Tuple2<Integer, Geometry>> result = new ArrayList<>();

    // Query the spatial index for intersecting envelopes.
    // STRtree.query returns a raw List; the index is expected to hold Integer partition ids.
    @SuppressWarnings("unchecked")
    final List<Integer> intersectingIds =
        spatialExpandedBoundaryIndex.query(geometry.getEnvelopeInternal());

    for (Integer partitionId : intersectingIds) {
      result.add(new Tuple2<>(partitionId, geometry));
    }

    return result.iterator();
  }

  /**
   * Check the geometry against the partition zones to find the IDs of the matching partitions.
   * Only returns the IDs. Uses the same structure as {@link #placeObject(Geometry)}: the original
   * quad-tree (queried with the centroid) in non-overlapped mode, and the expanded-boundary index
   * (queried with the envelope) in overlapped mode, where the geometry can be in multiple ranges
   * because ranges can overlap.
   *
   * @param geometry Geometry object to be checked.
   * @return Set of integers representing the IDs of the matching partitions.
   * @throws NullPointerException if {@code geometry} is {@code null}
   */
  @Override
  public Set<Integer> getKeys(Geometry geometry) {
    Objects.requireNonNull(geometry, "spatialObject");

    if (useNonOverlapped) {
      // knn join uses the centroid of the geometry
      return partitionTree.getKeys(geometry.getCentroid());
    }

    // use the expanded boundaries
    // STRtree.query returns a raw List; the index is expected to hold Integer partition ids.
    @SuppressWarnings("unchecked")
    final List<Integer> intersectingIds =
        spatialExpandedBoundaryIndex.query(geometry.getEnvelopeInternal());

    return new HashSet<>(intersectingIds);
  }

  /**
   * Returns the leaf zones of the original (non-overlapped) quad-tree.
   *
   * @return the leaf zones as reported by the partition tree
   */
  @Override
  public List<Envelope> fetchLeafZones() {
    return partitionTree.fetchLeafZones();
  }

  /**
   * Builds the quad-tree partitioning structure and calculates the expanded boundaries for
   * efficient spatial partitioning.
   *
   * <p>This method performs the following steps: 1. Computes a minimum level ({@code
   * log4(numPartitions)}, truncated, never negative) that the quad-tree is forced to grow to, so
   * the actual number of partitions might slightly differ from the specified number. 2.
   * Constructs the quad-tree partitioning from the inserted samples and the boundary. 3. Creates
   * the expanded boundaries and their spatial index by building an STR (Sort-Tile-Recursive) tree
   * from the samples and the given number of neighbor samples. 4. Clears the samples and drops
   * the elements stored in the quad-tree to avoid broadcasting them to all nodes involved in
   * partitioning.
   *
   * @param neighborSampleNumber the number of neighbor samples to consider for building the STR
   *     tree.
   * @throws Exception propagated from the underlying partitioner
   */
  public void build(int neighborSampleNumber) throws Exception {
    // Force the quad-tree to grow up to a certain level
    // So the actual num of partitions might be slightly different
    int minLevel = (int) Math.max(Math.log(numPartitions) / Math.log(4), 0);
    QuadTreeRTPartitioning quadTreeRTPartitioning =
        new QuadTreeRTPartitioning(samples, boundary, numPartitions, minLevel);
    partitionTree = quadTreeRTPartitioning.getPartitionTree();

    // Create the expanded boundaries
    quadTreeRTPartitioning.buildSTRTree(samples, neighborSampleNumber);
    expandedBoundaries = quadTreeRTPartitioning.getMbrs();
    spatialExpandedBoundaryIndex = quadTreeRTPartitioning.getMbrSpatialIndex();

    // Make sure not to broadcast all the samples used to build the Quad
    // tree to all nodes which are doing partitioning
    samples.clear();
    partitionTree.dropElements();
  }

  /**
   * Returns the original (non-overlapped) partitioning quad-tree.
   *
   * @return the internal quad-tree; {@code null} until it has been populated by {@link
   *     #build(int)} or copied from another tree
   */
  public StandardQuadTree<?> getQuadTree() {
    return partitionTree;
  }
}
